/*
 * Copyright (c) 2022. China Mobile (SuZhou) Software Technology Co.,Ltd. All rights reserved.
 * Lakehouse is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package com.chinamobile.cmss.lakehouse.common.utils.spark

import java.util.Locale

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{CreateTableHeaderContext, MultipartIdentifierContext, TableIdentifierContext}
import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, ParseException, SqlBaseParser}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.internal.VariableSubstitution

import scala.collection.mutable.ArrayBuffer

/**
  * SQL parser that, in addition to producing a [[LogicalPlan]], records every
  * table identifier referenced by the parsed statement.
  */
class SparkSqlParserExtend() extends AbstractSqlParser {
  /** AST builder whose visit callbacks append table references to [[TableHolder]]. */
  val astBuilder = new SparkSqlAstBuilderExtend()

  // Expands ${var}-style variables before parsing, mirroring SparkSqlParser.
  private val substitutor = new VariableSubstitution

  protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
    super.parse(substitutor.substitute(command))(toResult)
  }

  /**
   * Parses `sqlText` and collects every table identifier the statement
   * references into `t`.
   *
   * The buffer is handed to the AST builder through the [[TableHolder]]
   * ThreadLocal; it is always removed in a finally block so the buffer never
   * leaks across parses on pooled threads.
   *
   * @param sqlText the SQL statement to parse
   * @param t       receives one [[TableIdentifierExtend]] per table reference
   * @return the parsed logical plan
   * @throws ParseException if the statement cannot be parsed or does not
   *                        produce a logical plan
   */
  def tables(sqlText: String, t: ArrayBuffer[TableIdentifierExtend]): LogicalPlan = {
    TableHolder.tables.set(t)
    try {
      parse(sqlText) { parser =>
        astBuilder.visitSingleStatement(parser.singleStatement()) match {
          case plan: LogicalPlan => plan
          case _ =>
            val position = Origin(None, None)
            throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
        }
      }
    } finally {
      // Clear the ThreadLocal even on parse failure.
      TableHolder.tables.remove()
    }
  }
}

/**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  *
  * In addition to the inherited behavior it records every table identifier it
  * visits into the thread-local buffer held by [[TableHolder]]. When no buffer
  * has been installed (e.g. the parser was entered through an inherited entry
  * point rather than SparkSqlParserExtend.tables), recording is skipped
  * instead of throwing a NullPointerException.
  */
class SparkSqlAstBuilderExtend() extends SparkSqlAstBuilder {
  override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = {
    val ti = super.visitTableIdentifier(ctx)
    // ThreadLocal.get() returns null outside an active tables(...) call — guard it.
    Option(TableHolder.tables.get()).foreach(_ += TableIdentifierExtend(ti.table, ti.database, None))
    ti
  }

  override def visitCreateTableHeader(ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
    val tableHeader = super.visitCreateTableHeader(ctx)

    // tableHeader._1 is the multi-part name of the table being created.
    TableHolder.tablesSet(tableHeader._1, None)
    tableHeader
  }

  override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = {
    val multipartIdentifier = super.visitMultipartIdentifier(ctx)

    // Tag identifiers whose parent clause starts with the INSERT keyword so
    // callers can distinguish write targets from read-only references.
    // Locale.ROOT avoids locale-sensitive lowercasing (Turkish dotless-i)
    // when matching the keyword.
    val ifInsert = ctx.parent.getChild(0).getText.toLowerCase(Locale.ROOT) match {
      case "insert" => Some("insert")
      case _ => None
    }

    TableHolder.tablesSet(multipartIdentifier, ifInsert)
    multipartIdentifier
  }
}

/**
  * Per-thread holder for the buffer that collects table identifiers during a
  * single parse. The buffer is installed by SparkSqlParserExtend.tables and
  * removed in its finally block; get() returns null when no parse is active.
  */
object TableHolder {
  val tables: ThreadLocal[ArrayBuffer[TableIdentifierExtend]] = new ThreadLocal[ArrayBuffer[TableIdentifierExtend]]

  /**
   * Records a table reference built from a one- or two-part identifier.
   * Identifiers with three or more parts (e.g. catalog.db.table) are ignored,
   * and the call is a no-op when no buffer is installed for this thread.
   *
   * @param multipartIdentifier name parts, e.g. Seq("db", "table") or Seq("table")
   * @param oprationType        Some("insert") for INSERT targets, None otherwise
   *                            (parameter name kept as-is for source
   *                            compatibility with named-argument callers)
   */
  def tablesSet(multipartIdentifier: Seq[String], oprationType: Option[String]): Unit = {
    // Guard the ThreadLocal: get() is null outside an active parse.
    Option(tables.get()).foreach { buffer =>
      multipartIdentifier match {
        case Seq(db, table) => buffer += TableIdentifierExtend(table, Option(db), oprationType)
        case Seq(table) => buffer += TableIdentifierExtend(table, None, oprationType)
        case _ => // three or more parts: not representable, skip
      }
    }
  }
}

/**
  * A table reference collected during parsing.
  *
  * @param table    the table name (last part of the identifier)
  * @param database the database/schema part, if the identifier was qualified
  * @param operator Some("insert") when the table is an INSERT target, None otherwise
  */
case class TableIdentifierExtend(table: String, database: Option[String] = None, operator: Option[String] = None) {
  // Alias mirroring TableIdentifier.identifier for interchangeable use.
  val identifier: String = table

  // Auxiliary constructor retained for Java callers / source compatibility.
  def this(table: String) = this(table, None, None)
}
